from functools import wraps
import os
from multiprocessing.dummy import Pool as ThreadPool
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np

# Knapsack capacity, and the (exclusive) upper bound used when drawing random
# item values -- also reused as the z-axis limit of the 3-D view.
CAPACITY, ITEM_VALUE = 200, 20000


class Item:
    """A knapsack item: a weight and a value. Extra arguments are ignored."""

    def __init__(self, weight, value, *_args, **_kwargs):
        self.weight, self.value = weight, value


n = 20
# Weights are drawn from 1..99, strictly below CAPACITY, so every item can fit
# in the knapsack on its own.
items = [Item(np.random.randint(1, 100), np.random.randint(1, ITEM_VALUE))
         for _ in range(n)]

SUM_ITEM_VALUE = sum(item.value for item in items)
# Radius of the circle on which the item bars are laid out in the 3-D view.
BLANK = 15 * n / 2
# endpoint=False: with the default endpoint=True the last angle is 2*pi, which
# lands on exactly the same spot as angle 0, so the last item's bar would be
# drawn on top of the first item's bar.
theta = np.linspace(0, 2*np.pi, n, endpoint=False)
x = np.cos(theta)*BLANK
y = np.sin(theta)*BLANK
# One (x, y) anchor per item, shape (n, 2).
COORDINATE = np.column_stack((x, y))

plt.ion()  # interactive mode, so the figure keeps updating while the script runs
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 unused import
# https://matplotlib.org/api/_as_gen/mpl_toolkits.mplot3d.axes3d.Axes3D.html?highlight=bar3d#mpl_toolkits.mplot3d.axes3d.Axes3D.bar3d
# bar3d(self, x, y, z, dx, dy, dz, color=None, zsort='average',  *args, **kwargs)
plt.style.use('fivethirtyeight')
fig = plt.figure()
# Left panel: items as 3-D bars. Right panel: DP / ant value history.
ax, ax2 = fig.add_subplot(121, projection='3d'), fig.add_subplot(122)

# Shared display state read by the plt_reset decorator; "?" = not computed yet.
dp_value = ant_value = "?"
dp_value_history, ant_value_history = [], []
dp_ans, ant_ans = [], []
ant_max_value = 0


def plt_reset(func):
    """Decorate a drawing function so every call starts from a fresh frame.

    Before calling ``func`` the wrapper:
      * clears both axes and re-applies the fixed limits of the 3-D view;
      * sets a title summarising the DP answer, the current ant answer, the
        best ant value so far and its relative error against the DP optimum
        (shown as "?" until the DP optimum is known and non-zero);
      * once ``dp_value`` / ``ant_value`` are no longer ``"?"``, appends them
        to their history lists and re-plots those histories on ``ax2``.

    Returns whatever ``func`` returns.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        ax.cla()
        ax2.cla()
        ax.axis('off')  # hide the 3-D axes; only the bars matter
        ax.set_xlim(-BLANK, BLANK)
        ax.set_ylim(-BLANK, BLANK)
        ax.set_zlim(0, ITEM_VALUE)
        dp_known = dp_value != "?"
        # Relative gap between the best ant value and the DP optimum. Only
        # computable once the optimum is a (non-zero) number; previously this
        # raised TypeError while dp_value was still "?".
        if dp_known and dp_value:
            error = round((1-ant_max_value/dp_value)*100, 3)
        else:
            error = "?"
        plt.title(
            f"dp_ans->{dp_value} {dp_ans}\n ant_ans->{ant_value} {ant_ans}\n history_max_value{ant_max_value} error:{error}%")
        if dp_known:
            dp_value_history.append(dp_value)
            ax2.plot(dp_value_history, color="red")
        if ant_value != "?":
            ant_value_history.append(ant_value)
            ax2.plot(ant_value_history, color="gold")
            # The DP answer is exact, so no feasible ant solution may exceed it.
            if dp_known:
                assert dp_value >= ant_value
        return func(*args, **kwargs)
    return wrapper


# ---------------------------------------------------------------------------
# DP solution: exact 0/1-knapsack optimum, used as the reference answer.
# ---------------------------------------------------------------------------

# dp[i, j] = best total value achievable with the first i items and capacity j.
dp = np.zeros([n+1, CAPACITY+1])

for i, item in enumerate(items, start=1):
    for j in range(CAPACITY+1):
        if j < item.weight:
            # The item does not fit at this capacity: inherit the answer without it.
            dp[i, j] = dp[i-1, j]
        else:
            # Best of skipping the item, or taking it on top of dp[i-1, j-weight].
            dp[i, j] = max(dp[i-1, j], dp[i-1, int(j-item.weight)]+item.value)

print(dp[n, CAPACITY])

# Backtrack: if dp[i, j] differs from dp[i-1, j], the optimum at (i, j) must
# include item i-1 (0-based), so take it and move to the remaining capacity.
j = CAPACITY
check = 0
dp_ans = []
for i in range(n, 0, -1):
    if dp[i, j] != dp[i-1, j]:
        idx = i-1
        dp_ans.append(idx)
        j -= items[idx].weight
        check += items[idx].value
dp_ans.reverse()  # collected from the last item backwards; report ascending
print(dp_ans)
dp_value = dp[n, CAPACITY]
# Sanity check: the reconstructed selection must add up to the optimum.
assert check == dp_value

@plt_reset
def show_knapsack_select(dp_ans):
    """Draw every item as a 3-D bar, highlighting the DP selection.

    Each bar sits at the item's COORDINATE anchor, has a square footprint of
    side sqrt(weight) and a height equal to the item's value. Items whose index
    is in ``dp_ans`` are red, all others blue. The figure is refreshed after
    each bar, so the bars appear one at a time.
    """
    for k, (item, (bx, by)) in enumerate(zip(items, COORDINATE)):
        side = np.sqrt(item.weight)
        colour = "red" if k in dp_ans else "blue"
        ax.bar3d(bx, by, 0, side, side, item.value, color=colour)
        ax.text(1.5*bx, 1.5*by, 0, k, fontsize=8)
        plt.pause(0.1)

# Hold the freshly created figure for 5 seconds, then draw the DP answer.
plt.pause(interval=5)
show_knapsack_select(dp_ans)


class Ant:
    """One ant's knapsack solution under construction. Extra arguments are ignored."""

    def __init__(self, *_args, **_kwargs):
        self.weight = 0  # total weight of the picked items
        self.value = 0   # total value of the picked items
        self.tabu = []   # indices of the picked items, in pick order


# Pheromone level per item; every item starts at `const`.
const = 1
tau = np.full(n, const, dtype=float)

# Heuristic (prior) desirability per item: its value-to-weight ratio.
eta = np.ones(n)
for idx, item in enumerate(items):
    eta[idx] = item.value / item.weight


def Pk_ij(allowed, alpha=5, beta=1, *, greedy=False, pheromone=None,
          heuristic=None, rng=None):
    '''
    State-transition rule: choose the next item v for an ant.

    Each candidate j in ``allowed`` gets the Ant-System probability

        P_k(j) = tau_j**alpha * eta_j**beta / sum_{s in allowed} tau_s**alpha * eta_s**beta

    and one candidate is drawn at random with those probabilities (roulette
    wheel). Previously the probabilities were computed and then discarded by
    an argmax, which made every ant's construction deterministic.

    Parameters
    ----------
    allowed : sequence of int
        Indices of the items the ant may still take; must be non-empty.
    alpha, beta : float
        Exponents weighting the pheromone and the heuristic term.
    greedy : bool, keyword-only
        If True, return the most probable candidate instead of sampling
        (the old deterministic behaviour).
    pheromone, heuristic : array-like, keyword-only
        Per-item pheromone / heuristic values; default to the module-level
        ``tau`` and ``eta``.
    rng : keyword-only
        Object with a ``choice(a, p=...)`` method, e.g. a
        ``numpy.random.Generator``; defaults to ``numpy.random``.

    Returns
    -------
    The chosen element of ``allowed``.

    Raises
    ------
    ValueError
        If ``allowed`` is empty.
    '''
    if len(allowed) == 0:
        raise ValueError("allowed must contain at least one candidate")
    pheromone = tau if pheromone is None else np.asarray(pheromone)
    heuristic = eta if heuristic is None else np.asarray(heuristic)
    index = np.asarray(allowed)
    top = np.power(pheromone[index], alpha) * np.power(heuristic[index], beta)
    if greedy:
        # argmax(top) == argmax(top / bottom); no need to normalise.
        return allowed[int(np.argmax(top))]
    bottom = np.sum(top)
    if np.isfinite(bottom) and bottom > 0:
        probs = top / bottom
    else:
        # Degenerate weights (all zero or overflow): draw uniformly instead.
        probs = np.full(len(allowed), 1 / len(allowed))
    rng = np.random if rng is None else rng
    return allowed[int(rng.choice(len(allowed), p=probs))]


N_cmax = 1000  # number of colony iterations
m = 100        # number of ants
Q = 1          # ant-cycle model: pheromone deposit = Q * ant[k].value / max
ants = [Ant() for _ in range(m)]


@plt_reset
def show_ant_system(ant_k):
    """Draw the items as 3-D bars, coloured by one ant's picks and the pheromone.

    Items in ``ant_k.tabu`` are red. Every other item is a shade of blue whose
    red/green components are max(0, 1 - tau[idx]), so the colour grows lighter
    as its pheromone falls below 1. The figure is refreshed once, after all
    bars are drawn.
    """
    chosen = ant_k.tabu
    for k, item in enumerate(items):
        bx, by = COORDINATE[k]
        side = np.sqrt(item.weight)
        if k in chosen:
            colour = "red"
        else:
            fade = max(0, 1-tau[k])
            colour = (fade, fade, 1)
        ax.bar3d(bx, by, 0, side, side, item.value, color=colour)
        ax.text(1.5*bx, 1.5*by, 0, k, fontsize=8)
    plt.pause(0.1)


# Weight of every item, for vectorised "does it still fit?" checks.
weight_check_array = np.asarray([item.weight for item in items])


def update_tau(p=0.1):
    """Update the pheromone with the ant-cycle model.

    Every item's pheromone evaporates by a factor (1 - p) (formula 6.4). Each
    ant then deposits Q * value / ant_max_value / m on every item it picked
    (formula 6.5), so higher-value solutions reinforce their items more.
    Prints the largest pheromone level afterwards.
    """
    global tau
    # The "+n" in tau_ij(t+n) can be ignored here.
    # Formula 6.5
    delta_tau = np.zeros([n])
    for ant_k in ants:
        # tabu never repeats an item, so fancy-index += adds once per item.
        delta_tau[ant_k.tabu] += Q*ant_k.value/ant_max_value/m
    # Formula 6.4
    tau = (1-p)*tau+delta_tau
    print(np.max(tau))


for N_c in range(N_cmax):
    for ant_k in ants:
        # Each ant starts from one random item...
        idx = np.random.choice(len(items))
        ant_k.weight = items[idx].weight
        ant_k.value = items[idx].value
        ant_k.tabu = [idx]
        # ...then keeps adding items via the transition rule until none fits.
        for i in range(n-1):
            fits = np.flatnonzero(weight_check_array <= CAPACITY-ant_k.weight)
            taken = set(ant_k.tabu)
            # Deterministic ascending order (previously set-iteration order).
            allowed = [k for k in fits if k not in taken]
            if not allowed:
                break
            v = Pk_ij(allowed)
            ant_k.weight += items[v].weight
            ant_k.value += items[v].value
            ant_k.tabu.append(v)
        # Record the finished solution, including its starting item. Updating
        # only after an addition missed ants that could add nothing, and left
        # ant_max_value at 0 (division by zero in update_tau) when n == 1.
        ant_max_value = max(ant_max_value, ant_k.value)

    update_tau()
    ant_value = ants[0].value
    ant_ans = sorted(ants[0].tabu)
    show_ant_system(ants[0])
